# -*- coding: utf-8 -*-
# Time    : 2024/5/10
# By      : Yang

import random
from typing import List, Tuple, Union
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader

from spm_tokenizer import CharTokenizer, SubwordTokenizer


class ASRDataset(Dataset):
    """Duration-bucketed ASR dataset that yields pre-collated mini-batches.

    Each ``__getitem__`` returns one *complete* mini-batch (features, feature
    lengths, decoder inputs, decoder targets), so the wrapping DataLoader
    should be created with ``batch_size=None`` (see ``get_dataloader``).
    """

    def __init__(self, wav_paths, wav_texts, wav_lengths, tokenizer: Union[CharTokenizer, SubwordTokenizer],
                 batch_size, batch_seconds, shuffle):
        """
            wav_paths: list of paths to wav files
            wav_texts: list of texts (transcripts, aligned with wav_paths)
            wav_lengths: list of lengths (in seconds) of wav files
            tokenizer: tokenizer to convert text to token ids
            batch_size: maximum number of samples per mini-batch
            batch_seconds: maximum total audio duration (seconds) per mini-batch
            shuffle: whether to shuffle the dataset
        """
        assert len(wav_paths) == len(wav_texts) == len(wav_lengths)
        super(ASRDataset, self).__init__()
        self.wav_paths = wav_paths
        self.wav_lengths = wav_lengths
        self.wav_texts = wav_texts
        max_wav_length = 20  # drop utterances longer than this (seconds) to bound memory
        self.samples = [(wav_paths[i], wav_texts[i], wav_lengths[i]) for i in range(len(wav_paths)) if
                        wav_lengths[i] <= max_wav_length]
        skipped = len(wav_paths) - len(self.samples)
        self.sr = 16000  # every wav is asserted to be at this sample rate

        self.tokenizer = tokenizer
        self.sos = tokenizer.sos_id  # start-of-sentence token id
        self.eos = tokenizer.eos_id  # end-of-sentence token id

        self.batch_size = batch_size
        self.batch_seconds = batch_seconds
        self.is_shuffle = shuffle
        self.minibatches = []
        self.shuffle()

        # max(..., 1) guards the division when the sample list is empty
        # (the original crashed with ZeroDivisionError in that case).
        print(f"ASRDataset: {len(self.samples)} samples, {len(self.minibatches)} mini-batches, "
              f"batch_size: {self.batch_size}, batch_seconds: {self.batch_seconds}, "
              f"max_wav_length: {max_wav_length}, skipped samples: {skipped}, "
              f"mean batch-size: {len(self.samples) / max(len(self.minibatches), 1):.2f}")

    def init_mini_batches(self):
        """Greedily pack samples into mini-batches.

        A mini-batch is closed as soon as adding the next sample would exceed
        either ``batch_size`` (count) or ``batch_seconds`` (total duration).
        Never emits an empty mini-batch (the original could, when a single
        sample alone exceeded ``batch_seconds``).
        """
        self.minibatches = []
        if self.is_shuffle:
            # Sort by duration so each bucket holds similar-length utterances,
            # minimizing padding inside a batch.
            self.samples = sorted(self.samples, key=lambda x: x[2])
        minibatch = []
        seconds = 0.0  # total duration of the current mini-batch, in seconds
        for sample in self.samples:
            _, _, length = sample
            if minibatch and (seconds + length > self.batch_seconds
                              or len(minibatch) >= self.batch_size):
                self.minibatches.append(minibatch)
                minibatch = []
                seconds = 0.0
            minibatch.append(sample)
            seconds += length
        if minibatch:  # flush the trailing, partially-filled batch
            self.minibatches.append(minibatch)

    def shuffle(self):
        """Re-shuffle samples, rebuild buckets, then shuffle bucket order."""
        if self.is_shuffle:
            random.shuffle(self.samples)
        self.init_mini_batches()
        if self.is_shuffle:
            random.shuffle(self.minibatches)

    def __len__(self):
        # Length is the number of mini-batches, not individual samples.
        return len(self.minibatches)

    def __getitem__(self, idx):
        """Fetch one pre-collated mini-batch.

        idx: mini-batch index
        returns:
            feat_data: (batch_size, seq_len, 80) zero-padded fbank features
            feat_lens: (batch_size,) true feature lengths
            ys_in_pad: (batch_size, max_token_len) decoder inputs, sos-padded
            ys_out_pad: (batch_size, max_token_len) decoder targets, -1-padded
        """
        batch = self.minibatches[idx]
        fbank = []
        for path, text, length in batch:
            wav, sr = torchaudio.load(path)
            assert sr == self.sr, f'sample rate mismatch: {sr} != {self.sr}'
            wav = wav * (1 << 15)  # rescale to int16 range for kaldi compatibility
            fb = torchaudio.compliance.kaldi.fbank(wav, num_mel_bins=80)  # (seq_len, 80)
            fbank.append(fb)
        ys_in, ys_out = [], []
        for path, text, length in batch:
            tokens = self.tokenizer.tokenize(text)
            ys_in.append([self.sos] + tokens)   # decoder input: <sos> + tokens
            ys_out.append(tokens + [self.eos])  # decoder target: tokens + <eos>

        max_fbank_len = max(fb.shape[0] for fb in fbank)
        feat_data = torch.zeros(len(fbank), max_fbank_len, 80)
        for i, fb in enumerate(fbank):
            # BUG FIX: original used `==` (a no-op comparison), so the
            # returned features were all zeros.
            feat_data[i, :fb.shape[0]] = fb

        feat_lens = torch.tensor([fb.shape[0] for fb in fbank]).long()

        pad_token_for_ys_in = self.sos  # padding for decoder inputs
        pad_token_for_ys_out = -1       # ignore-index padding for the loss
        max_token_len = max(len(tokens) for tokens in ys_in)

        ys_in_pad = torch.full((len(ys_in), max_token_len),
                               pad_token_for_ys_in, dtype=torch.long)
        ys_out_pad = torch.full((len(ys_out), max_token_len),
                                pad_token_for_ys_out, dtype=torch.long)

        for i, tokens in enumerate(ys_in):
            ys_in_pad[i, :len(tokens)] = torch.tensor(tokens)

        for i, tokens in enumerate(ys_out):
            ys_out_pad[i, :len(tokens)] = torch.tensor(tokens)

        return feat_data, feat_lens, ys_in_pad, ys_out_pad


def get_dataloader(wav_paths, wav_texts, wav_lengths, tokenizer, batch_size, batch_seconds, shuffle):
    """Build a DataLoader over an ASRDataset.

    NOTE(fix): the original signature named the 2nd/3rd parameters
    (wav_lengths, wav_texts) but forwarded them positionally to ASRDataset,
    which expects (wav_texts, wav_lengths) — so positional callers (see
    __main__) were already passing texts second. The parameter names are
    corrected here to match the actual positional semantics; positional call
    sites are unaffected.

    ASRDataset yields fully-collated mini-batches, so batch_size=None
    disables the DataLoader's own batching.
    """
    dataset = ASRDataset(wav_paths, wav_texts, wav_lengths, tokenizer, batch_size, batch_seconds, shuffle)
    dataloader = DataLoader(dataset, batch_size=None)
    return dataloader


if __name__ == "__main__":
    tokenizer = CharTokenizer()
    
    with open("../../../data/LRS2/train.paths") as f:
        wav_paths = f.read().splitlines()
    with open("../../../data/LRS2/train.text") as f:
        wav_texts = f.read().splitlines()
    with open("../../../data/LRS2/train.lengths") as f:
        wav_lengths = f.read().splitlines()
    
    wav_lengths = [float(length) for length in wav_lengths]
    batch_size = 32
    
    data_loader = get_dataloader(wav_paths, wav_texts, wav_lengths, tokenizer, batch_size, 32, shuffle=True)
    
    max_seq_len = 0
    for i, (feat_data, feat_lens, ys_in_pad, ys_out_pad) in enumerate(data_loader):
        print(
            f"i: {i:04d} -> fbank_feat: {feat_data.shape}, feat_lens: {feat_lens.shape}, "
            f"ys_in_pad: {ys_in_pad.shape}, ys_out_pad: {ys_out_pad.shape}"
        )
        max_seq_len = max(max_seq_len, feat_data.shape[1])
    print(f"max_seq_len: {max_seq_len}")
